import anndata
import itertools
import numpy as np
import pandas as pd
import sys
from enformer_pytorch.data import str_to_one_hot
import scipy.signal
import multiprocessing as mp


def scan_motifs(seq, seq_id, pwms):
    """Find the positions in one sequence where each motif scores above threshold.

    Args:
        seq: Sequence string, one-hot encoded via ``str_to_one_hot``.
        seq_id: Identifier copied into every output record for this sequence.
        pwms: DataFrame with one row per motif; each row must expose a 2-D
            ``kernel`` and a scalar ``threshold``. The row index is used as
            the motif identifier.

    Returns:
        A list of ``[seq_id, motif_index, positions]`` records, one per motif
        that has at least one hit. ``positions`` is a list of column indices
        of the 'valid' convolution output whose score exceeds the threshold.
    """
    # Transposed one-hot matrix; presumably (channels, length) so the kernel
    # slides along the sequence axis -- TODO confirm against str_to_one_hot.
    onehot = str_to_one_hot(seq).numpy().T
    hits = []
    for pwm in pwms.itertuples():
        # NOTE(review): convolve2d flips the kernel (true convolution, not
        # cross-correlation); presumably the kernels are stored accordingly.
        scores = scipy.signal.convolve2d(onehot, pwm.kernel, mode='valid')
        positions = list(np.where(scores > pwm.threshold)[1])
        if not positions:
            continue
        hits.append([seq_id, pwm.Index, positions])
    return hits


def scan_seqs_motifs(df, pwm_df, num_workers=32):
    """Scan every sequence in ``df`` against every motif in ``pwm_df`` in parallel.

    Args:
        df: DataFrame with a ``Sequence`` column; its index supplies the
            sequence identifiers.
        pwm_df: Motif table handed unchanged to ``scan_motifs`` for each
            sequence.
        num_workers: Number of worker processes in the multiprocessing pool.

    Returns:
        A DataFrame indexed by ``seq_id`` with columns ``Matrix_id`` (str) and
        ``pos``, holding one row per individual motif hit. When nothing is
        found (or ``df`` is empty) an empty frame with the same columns is
        returned.
    """
    columns = ['seq_id', 'Matrix_id', 'pos']

    with mp.Pool(num_workers) as pool:
        # One task per sequence; the motif table is repeated for each task.
        items = zip(df.Sequence.tolist(), df.index.tolist(), itertools.repeat(pwm_df))
        per_seq_hits = pool.starmap(scan_motifs, items)
    # starmap blocks until every task is done, so the pool can be released
    # before the (serial) post-processing below.

    records = list(itertools.chain.from_iterable(per_seq_hits))
    # Passing the column names to the constructor (instead of assigning
    # ``.columns`` afterwards) keeps this working when there are no hits:
    # a column-less empty frame would raise on a 3-name assignment.
    res = pd.DataFrame(records, columns=columns)
    res['Matrix_id'] = res['Matrix_id'].astype(str)
    # Convert numpy integers to plain Python ints before exploding.
    res['pos'] = res['pos'].apply(lambda x: [int(p) for p in x])
    return res.explode('pos').set_index('seq_id')


def calculate_motif_counts(sites_df):
    """Build a sequence-by-motif matrix of hit counts wrapped in an AnnData.

    Args:
        sites_df: Table of motif hits with ``seq_id``, ``Matrix_id`` and
            ``pos`` available to ``pd.pivot_table``.

    Returns:
        ``anndata.AnnData`` built from the pivot table, with one row per
        ``seq_id`` and one column per ``Matrix_id``; each cell is the count of
        ``pos`` entries, with missing combinations filled with 0.
    """
    counts = pd.pivot_table(
        sites_df,
        values='pos',
        index='seq_id',
        columns='Matrix_id',
        aggfunc='count',
    )
    # Sequence/motif pairs with no hits come out as NaN; report them as 0.
    counts = counts.fillna(0)
    return anndata.AnnData(counts)
